{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "XscMV3rSoQBB"
      },
      "source": [
        "#### Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.\n",
        "\n",
        "#### Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Iie8E5roo8Bw"
      },
      "source": [
        "#### Full license text"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "Tb06_qhlohjx"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "# \n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "# \n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YnJgg8sp15tb"
      },
      "source": [
        "# A (very) basic GAN for MNIST in JAX/Haiku\n",
        "\n",
        "Based on a TensorFlow tutorial written by Mihaela Rosca.\n",
        "\n",
        "Original GAN paper: https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ZEPQa39sDEBi"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "fe9dyfnF14rh"
      },
      "outputs": [],
      "source": [
        "# Uncomment the line below if running on colab.research.google.com.\n",
        "# !pip install dm-haiku optax\n",
        "\n",
        "import functools\n",
        "from typing import Any, NamedTuple\n",
        "\n",
        "import haiku as hk\n",
        "import jax\n",
        "import optax\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "z3RU20wuCSPs"
      },
      "source": [
        "## Define the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "c0b5luja3zT5"
      },
      "outputs": [],
      "source": [
        "# Load MNIST once at cell level; `make_dataset` below reads the \"train\"\n",
        "# split from this global instead of calling `tfds.load` on every call.\n",
        "mnist = tfds.load(\"mnist\")\n",
        "\n",
        "\n",
        "def make_dataset(batch_size, seed=1):\n",
        "  \"\"\"Builds an endless iterator over batches of MNIST training images.\n",
        "\n",
        "  Args:\n",
        "    batch_size: Number of images per batch. The shuffle buffer holds\n",
        "      `10 * batch_size` examples.\n",
        "    seed: Seed passed to the shuffle op.\n",
        "\n",
        "  Returns:\n",
        "    An iterator (via `tfds.as_numpy`) over batches of images scaled to\n",
        "    [-1, 1].\n",
        "  \"\"\"\n",
        "\n",
        "  def _to_symmetric_range(example):\n",
        "    # `convert_image_dtype` gives floats in [0, 1]; the affine map below\n",
        "    # moves them to [-1, 1] to stabilize training.\n",
        "    pixels = tf.image.convert_image_dtype(example[\"image\"], tf.float32)\n",
        "    return 2.0 * pixels - 1.0\n",
        "\n",
        "  shuffle_buffer = 10 * batch_size\n",
        "  pipeline = (\n",
        "      mnist[\"train\"]\n",
        "      .map(map_func=_to_symmetric_range,\n",
        "           num_parallel_calls=tf.data.experimental.AUTOTUNE)\n",
        "      .cache()\n",
        "      .shuffle(shuffle_buffer, seed=seed)\n",
        "      .repeat()\n",
        "      .batch(batch_size)\n",
        "  )\n",
        "  return iter(tfds.as_numpy(pipeline))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "rCIdYs3AH2PE"
      },
      "source": [
        "## Define the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rD-EeF4IH4m8"
      },
      "outputs": [],
      "source": [
        "class Generator(hk.Module):\n",
        "  \"\"\"Generator network mapping noise latents to images.\"\"\"\n",
        "\n",
        "  def __init__(self, output_channels=(32, 1), name=None):\n",
        "    \"\"\"Stores the channel count of each upsampling layer.\n",
        "\n",
        "    Args:\n",
        "      output_channels: Output channels of each stride-2 transposed\n",
        "        convolution, in order.\n",
        "      name: Optional Haiku module name.\n",
        "    \"\"\"\n",
        "    super().__init__(name=name)\n",
        "    self.output_channels = output_channels\n",
        "\n",
        "  def __call__(self, x):\n",
        "    \"\"\"Maps noise latents to images.\"\"\"\n",
        "    # Project the latents, then fold the features into a 7x7 map with 64\n",
        "    # channels while keeping the leading (batch) dimension.\n",
        "    hidden = hk.Linear(7 * 7 * 64)(x)\n",
        "    hidden = jnp.reshape(hidden, hidden.shape[:1] + (7, 7, 64))\n",
        "\n",
        "    # Each stage applies a ReLU followed by a stride-2 transposed convolution.\n",
        "    for num_channels in self.output_channels:\n",
        "      upsample = hk.Conv2DTranspose(output_channels=num_channels,\n",
        "                                    kernel_shape=[5, 5],\n",
        "                                    stride=2,\n",
        "                                    padding=\"SAME\")\n",
        "      hidden = upsample(jax.nn.relu(hidden))\n",
        "\n",
        "    # The tanh keeps the generated samples in the same range as the data.\n",
        "    return jnp.tanh(hidden)\n",
        "\n",
        "\n",
        "class Discriminator(hk.Module):\n",
        "  \"\"\"Discriminator network producing two-class (fake/real) logits.\"\"\"\n",
        "\n",
        "  def __init__(self,\n",
        "               output_channels=(8, 16, 32, 64, 128),\n",
        "               strides=(2, 1, 2, 1, 2),\n",
        "               name=None):\n",
        "    \"\"\"Stores the per-layer convolution settings.\n",
        "\n",
        "    Args:\n",
        "      output_channels: Output channels of each convolution, in order.\n",
        "      strides: Stride of each convolution, paired with `output_channels`.\n",
        "      name: Optional Haiku module name.\n",
        "    \"\"\"\n",
        "    super().__init__(name=name)\n",
        "    self.strides = strides\n",
        "    self.output_channels = output_channels\n",
        "\n",
        "  def __call__(self, x):\n",
        "    \"\"\"Classifies images as real or fake.\"\"\"\n",
        "    features = x\n",
        "    layer_specs = zip(self.output_channels, self.strides)\n",
        "    for num_channels, conv_stride in layer_specs:\n",
        "      conv = hk.Conv2D(output_channels=num_channels,\n",
        "                       kernel_shape=[5, 5],\n",
        "                       stride=conv_stride,\n",
        "                       padding=\"SAME\")\n",
        "      features = jax.nn.leaky_relu(conv(features), negative_slope=0.2)\n",
        "\n",
        "    # Two output classes: 0 = input is fake, 1 = input is real.\n",
        "    flat_features = hk.Flatten()(features)\n",
        "    return hk.Linear(2)(flat_features)\n",
        "\n",
        "\n",
        "def tree_shape(xs):\n",
        "  return jax.tree.map(lambda x: x.shape, xs)\n",
        "\n",
        "\n",
        "def sparse_softmax_cross_entropy(logits, labels):\n",
        "  one_hot_labels = jax.nn.one_hot(labels, logits.shape[-1])\n",
        "  return -jnp.sum(one_hot_labels * jax.nn.log_softmax(logits), axis=-1)\n",
        "\n",
        "\n",
        "class GANTuple(NamedTuple):\n",
        "  \"\"\"Pairs a generator-side value with its discriminator-side counterpart.\"\"\"\n",
        "  gen: Any  # Generator entry (e.g. parameters, optimizer or optimizer state).\n",
        "  disc: Any  # Discriminator entry of the same kind.\n",
        "\n",
        "\n",
        "class GANState(NamedTuple):\n",
        "  \"\"\"Training state: network parameters together with optimizer states.\"\"\"\n",
        "  params: GANTuple  # Generator and discriminator parameters.\n",
        "  opt_state: GANTuple  # Generator and discriminator optimizer states.\n",
        "\n",
        "\n",
        "class GAN:\n",
        "  \"\"\"A basic GAN.\n",
        "\n",
        "  Holds the Haiku-transformed generator and discriminator together with one\n",
        "  Adam optimizer for each, and exposes jitted initialization and update steps.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_latents):\n",
        "    \"\"\"Builds the network transforms and the optimizers.\n",
        "\n",
        "    Args:\n",
        "      num_latents: Size of the noise latent vector fed to the generator.\n",
        "    \"\"\"\n",
        "    self.num_latents = num_latents\n",
        "\n",
        "    # Define the Haiku network transforms.\n",
        "    # We don't use BatchNorm so we don't use `with_state`.\n",
        "    self.gen_transform = hk.without_apply_rng(\n",
        "        hk.transform(lambda *args: Generator()(*args)))\n",
        "    self.disc_transform = hk.without_apply_rng(\n",
        "        hk.transform(lambda *args: Discriminator()(*args)))\n",
        "\n",
        "    # Build the optimizers.\n",
        "    self.optimizers = GANTuple(gen=optax.adam(1e-4, b1=0.5, b2=0.9),\n",
        "                               disc=optax.adam(1e-4, b1=0.5, b2=0.9))\n",
        "\n",
        "  # `self` (argument 0) is marked static so that `jax.jit` does not try to\n",
        "  # trace it as an array argument.\n",
        "  @functools.partial(jax.jit, static_argnums=0)\n",
        "  def initial_state(self, rng, batch):\n",
        "    \"\"\"Returns the initial parameters and optimizer states.\n",
        "\n",
        "    Args:\n",
        "      rng: PRNG key; split into one key per network.\n",
        "      batch: Example batch of images. It is passed to the discriminator's\n",
        "        `init`, and its leading dimension sizes the dummy latents.\n",
        "\n",
        "    Returns:\n",
        "      A `GANState` holding the initial parameters and optimizer states.\n",
        "    \"\"\"\n",
        "    # Generate dummy latents for the generator.\n",
        "    dummy_latents = jnp.zeros((batch.shape[0], self.num_latents))\n",
        "\n",
        "    # Get initial network parameters.\n",
        "    rng_gen, rng_disc = jax.random.split(rng)\n",
        "    params = GANTuple(gen=self.gen_transform.init(rng_gen, dummy_latents),\n",
        "                      disc=self.disc_transform.init(rng_disc, batch))\n",
        "    # Show the parameter shapes of both networks. Because this method is\n",
        "    # jitted, these Python prints run while the function is being traced.\n",
        "    print(\"Generator: \\n\\n{}\\n\".format(tree_shape(params.gen)))\n",
        "    print(\"Discriminator: \\n\\n{}\\n\".format(tree_shape(params.disc)))\n",
        "\n",
        "    # Initialize the optimizers.\n",
        "    opt_state = GANTuple(gen=self.optimizers.gen.init(params.gen),\n",
        "                         disc=self.optimizers.disc.init(params.disc))\n",
        "\n",
        "    return GANState(params=params, opt_state=opt_state)\n",
        "\n",
        "  def sample(self, rng, gen_params, num_samples):\n",
        "    \"\"\"Generates images from noise latents.\n",
        "\n",
        "    Args:\n",
        "      rng: PRNG key used to draw the standard-normal latents.\n",
        "      gen_params: Generator parameters.\n",
        "      num_samples: Number of images to generate.\n",
        "\n",
        "    Returns:\n",
        "      The generator output for `num_samples` freshly drawn latents.\n",
        "    \"\"\"\n",
        "    latents = jax.random.normal(rng, shape=(num_samples, self.num_latents))\n",
        "    return self.gen_transform.apply(gen_params, latents)\n",
        "\n",
        "  def gen_loss(self, gen_params, rng, disc_params, batch):\n",
        "    \"\"\"Generator loss.\n",
        "\n",
        "    The loss is the negative log-probability that the discriminator assigns\n",
        "    to class 1 (real) on generated samples, averaged over the batch.\n",
        "\n",
        "    Args:\n",
        "      gen_params: Generator parameters (the argument differentiated in\n",
        "        `update`).\n",
        "      rng: PRNG key used to sample latents.\n",
        "      disc_params: Discriminator parameters.\n",
        "      batch: Batch of real images; only its leading dimension is used, to\n",
        "        choose how many samples to generate.\n",
        "\n",
        "    Returns:\n",
        "      The scalar mean generator loss.\n",
        "    \"\"\"\n",
        "    # Sample from the generator.\n",
        "    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])\n",
        "\n",
        "    # Evaluate using the discriminator. Recall class 1 is real.\n",
        "    fake_logits = self.disc_transform.apply(disc_params, fake_batch)\n",
        "    # Take log-probabilities straight from `log_softmax`: the log of a\n",
        "    # `softmax` output becomes -inf once that probability underflows to zero,\n",
        "    # which can then poison the gradients.\n",
        "    loss = -jax.nn.log_softmax(fake_logits)[:, 1]\n",
        "\n",
        "    return jnp.mean(loss)\n",
        "\n",
        "  def disc_loss(self, disc_params, rng, gen_params, batch):\n",
        "    \"\"\"Discriminator loss.\n",
        "\n",
        "    Args:\n",
        "      disc_params: Discriminator parameters (the argument differentiated in\n",
        "        `update`).\n",
        "      rng: PRNG key used to sample latents.\n",
        "      gen_params: Generator parameters.\n",
        "      batch: Batch of real images.\n",
        "\n",
        "    Returns:\n",
        "      The scalar mean over the batch of the real-image cross-entropy plus\n",
        "      the fake-image cross-entropy.\n",
        "    \"\"\"\n",
        "    # Sample from the generator.\n",
        "    fake_batch = self.sample(rng, gen_params, num_samples=batch.shape[0])\n",
        "\n",
        "    # For efficiency we process both the real and fake data in one pass.\n",
        "    real_and_fake_batch = jnp.concatenate([batch, fake_batch], axis=0)\n",
        "    real_and_fake_logits = self.disc_transform.apply(disc_params,\n",
        "                                                     real_and_fake_batch)\n",
        "    real_logits, fake_logits = jnp.split(real_and_fake_logits, 2, axis=0)\n",
        "\n",
        "    # Class 1 is real.\n",
        "    real_labels = jnp.ones((batch.shape[0],), dtype=jnp.int32)\n",
        "    real_loss = sparse_softmax_cross_entropy(real_logits, real_labels)\n",
        "\n",
        "    # Class 0 is fake.\n",
        "    fake_labels = jnp.zeros((batch.shape[0],), dtype=jnp.int32)\n",
        "    fake_loss = sparse_softmax_cross_entropy(fake_logits, fake_labels)\n",
        "\n",
        "    return jnp.mean(real_loss + fake_loss)\n",
        "\n",
        "  @functools.partial(jax.jit, static_argnums=0)\n",
        "  def update(self, rng, gan_state, batch):\n",
        "    \"\"\"Performs a parameter update.\n",
        "\n",
        "    Both gradients are evaluated at the parameters in `gan_state`, i.e. the\n",
        "    generator step sees the discriminator as it was before this update.\n",
        "\n",
        "    Args:\n",
        "      rng: PRNG key; split into a carry-over key and one key per loss.\n",
        "      gan_state: Current `GANState`.\n",
        "      batch: Batch of real images.\n",
        "\n",
        "    Returns:\n",
        "      A tuple `(rng, gan_state, log)` with the carry-over key, the updated\n",
        "      state, and a dict holding `gen_loss` and `disc_loss`.\n",
        "    \"\"\"\n",
        "    rng, rng_gen, rng_disc = jax.random.split(rng, 3)\n",
        "\n",
        "    # Update the discriminator.\n",
        "    disc_loss, disc_grads = jax.value_and_grad(self.disc_loss)(\n",
        "        gan_state.params.disc,\n",
        "        rng_disc,\n",
        "        gan_state.params.gen,\n",
        "        batch)\n",
        "    disc_update, disc_opt_state = self.optimizers.disc.update(\n",
        "        disc_grads, gan_state.opt_state.disc)\n",
        "    disc_params = optax.apply_updates(gan_state.params.disc, disc_update)\n",
        "\n",
        "    # Update the generator.\n",
        "    gen_loss, gen_grads = jax.value_and_grad(self.gen_loss)(\n",
        "        gan_state.params.gen,\n",
        "        rng_gen,\n",
        "        gan_state.params.disc,\n",
        "        batch)\n",
        "    gen_update, gen_opt_state = self.optimizers.gen.update(\n",
        "        gen_grads, gan_state.opt_state.gen)\n",
        "    gen_params = optax.apply_updates(gan_state.params.gen, gen_update)\n",
        "\n",
        "    params = GANTuple(gen=gen_params, disc=disc_params)\n",
        "    opt_state = GANTuple(gen=gen_opt_state, disc=disc_opt_state)\n",
        "    gan_state = GANState(params=params, opt_state=opt_state)\n",
        "    log = {\n",
        "        \"gen_loss\": gen_loss,\n",
        "        \"disc_loss\": disc_loss,\n",
        "    }\n",
        "\n",
        "    return rng, gan_state, log"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "hvYgaHYu_PPz"
      },
      "source": [
        "## Train the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "8ZC_oHCCKX2J"
      },
      "outputs": [],
      "source": [
        "#@title {vertical-output: true}\n",
        "\n",
        "num_steps = 20001\n",
        "# Log roughly 100 times over the run. Clamp to 1 so that a short run\n",
        "# (`num_steps < 100`) does not make `step % log_every` divide by zero.\n",
        "log_every = max(1, num_steps // 100)\n",
        "\n",
        "# Let's see what hardware we're working with. The training takes a few\n",
        "# minutes on a GPU, a bit longer on CPU.\n",
        "print(f\"Number of devices: {jax.device_count()}\")\n",
        "print(\"Device:\", jax.devices()[0].device_kind)\n",
        "print(\"\")\n",
        "\n",
        "# Make the dataset.\n",
        "dataset = make_dataset(batch_size=64)\n",
        "\n",
        "# The model.\n",
        "gan = GAN(num_latents=20)\n",
        "\n",
        "# Top-level RNG.\n",
        "rng = jax.random.PRNGKey(1729)\n",
        "\n",
        "# Initialize the network and optimizer. The first batch is consumed here to\n",
        "# provide the input shapes.\n",
        "rng, rng1 = jax.random.split(rng)\n",
        "gan_state = gan.initial_state(rng1, next(dataset))\n",
        "\n",
        "# Logged history, one entry per logging step.\n",
        "steps = []\n",
        "gen_losses = []\n",
        "disc_losses = []\n",
        "\n",
        "for step in range(num_steps):\n",
        "  # `update` returns the carry-over key, so `rng` advances every step.\n",
        "  rng, gan_state, log = gan.update(rng, gan_state, next(dataset))\n",
        "\n",
        "  # Log the losses.\n",
        "  if step % log_every == 0:\n",
        "    # It's important to call `device_get` here so we don't take up device\n",
        "    # memory by saving the losses.\n",
        "    log = jax.device_get(log)\n",
        "    gen_loss = log[\"gen_loss\"]\n",
        "    disc_loss = log[\"disc_loss\"]\n",
        "    print(f\"Step {step}: \"\n",
        "          f\"gen_loss = {gen_loss:.3f}, disc_loss = {disc_loss:.3f}\")\n",
        "    steps.append(step)\n",
        "    gen_losses.append(gen_loss)\n",
        "    disc_losses.append(disc_loss)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "RpAYrshvKlsw"
      },
      "source": [
        "## Visualize the losses\n",
        "Unlike losses for classifiers or VAEs, GAN losses do not decrease steadily, instead going up and down depending on the training dynamics."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "7QkzckE3Komt"
      },
      "outputs": [],
      "source": [
        "sns.set_style(\"whitegrid\")\n",
        "\n",
        "fig, axes = plt.subplots(1, 2, figsize=(20, 6))\n",
        "\n",
        "# Plot the discriminator loss.\n",
        "axes[0].plot(steps, disc_losses, \"-\")\n",
        "# The discriminator loss is the sum of the real and the fake cross-entropy\n",
        "# terms, so a discriminator that outputs probability 0.5 for every input\n",
        "# scores log(2) + log(2) = 2 * log(2).\n",
        "axes[0].plot(steps, np.full(len(steps), 2 * np.log(2)), \"r--\",\n",
        "             label=\"Discriminator is being fooled\")\n",
        "axes[0].legend(fontsize=20)\n",
        "axes[0].set_title(\"Discriminator loss\", fontsize=20)\n",
        "axes[0].set_xlabel(\"Step\", fontsize=20)\n",
        "\n",
        "# Plot the generator loss.\n",
        "axes[1].plot(steps, gen_losses, \"-\")\n",
        "axes[1].set_title(\"Generator loss\", fontsize=20)\n",
        "axes[1].set_xlabel(\"Step\", fontsize=20);"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GtYGGrmgKoyd"
      },
      "source": [
        "## Visualize samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ZFr7_quYK29E"
      },
      "outputs": [],
      "source": [
        "#@title {vertical-output: true}\n",
        "\n",
        "def make_grid(samples, num_cols=8, rescale=True):\n",
        "  batch_size, height, width = samples.shape\n",
        "  assert batch_size % num_cols == 0\n",
        "  num_rows = batch_size // num_cols\n",
        "  # We want samples.shape == (height * num_rows, width * num_cols).\n",
        "  samples = samples.reshape(num_rows, num_cols, height, width)\n",
        "  samples = samples.swapaxes(1, 2)\n",
        "  samples = samples.reshape(height * num_rows, width * num_cols)\n",
        "  return samples\n",
        "\n",
        "\n",
        "# Generate samples from the trained generator using a fixed key.\n",
        "# NOTE: this rebinds the notebook-level `rng` that the training cell uses.\n",
        "rng = jax.random.PRNGKey(12)\n",
        "samples = gan.sample(rng, gan_state.params.gen, num_samples=64)\n",
        "# Fetch the samples from the device and drop the trailing axis, since\n",
        "# `make_grid` unpacks exactly (batch, height, width).\n",
        "samples = jax.device_get(samples)\n",
        "samples = samples.squeeze(axis=-1)\n",
        "# Our model outputs values in [-1, 1] so scale it back to [0, 1].\n",
        "samples = (samples + 1.0) / 2.0\n",
        "\n",
        "# Select the gray colormap, hide the axes, and show all samples as a single\n",
        "# tiled image.\n",
        "plt.gray()\n",
        "plt.axis(\"off\")\n",
        "samples_grid = make_grid(samples)\n",
        "plt.imshow(samples_grid);"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Iie8E5roo8Bw"
      ],
      "name": "MNIST GAN.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
